0.1 Re-classifying cycles for Birth Control

In 00_cycle_classification_BC_expl.Rdm, we have explored the best way to re-classify cycles according to the birth control method (pill vs none/condoms).

We identified input variables that were useful to classify cycles and tested several methods to re-classify these cycles.

The best strategy seems to be

  1. train an SVM on the original birth_control labels to obtain probabilities

  2. define (and not train) an HMM to add memory to the assignment and use the Viterbi algorithm to obtain the most likely BC.

  3. define a confidence score for that new assignment that reflects the asymmetry in the probability distribution (from 1.)

  4. select users and cycles based on that confidence score.

0.1.1 Loading data

Loading cycles

file.copy(paste0(IO$output_data, "cycles.Rdata"), paste0(IO$tmp_data, "cycles_before_BC_re_classification.Rdata"))
## [1] FALSE
load(paste0(IO$output_data, "cycles.Rdata"), verbose = TRUE)
## Loading objects:
##   cycles
cycles$birth_control_CLUE = factor(cycles$birth_control_CLUE, levels = c("none / condoms","pill","did not enter"))
cycles$user_id_n = as.numeric(factor(cycles$user_id))

0.1.2 Input variables

inputs = c("cycle_length","n_days_obs","n_pill","n_prot_sex","n_unprot_sex","diff_cl_median_3c","cl_sd_3c","cycle_nb","period_length","n_egg_white_fluid", "cycle_start")
# ,"score_n_days_obs","score_BC","score_CL"
inputs
##  [1] "cycle_length"      "n_days_obs"        "n_pill"           
##  [4] "n_prot_sex"        "n_unprot_sex"      "diff_cl_median_3c"
##  [7] "cl_sd_3c"          "cycle_nb"          "period_length"    
## [10] "n_egg_white_fluid" "cycle_start"
# ,"score_n_days_obs","score_BC","score_CL"
# For every input except cycle_start, draw the density of that variable per
# original birth-control label (frequency polygon, bin width 1).
# setdiff() is used instead of inputs[-which(...)]: a negative index built from
# an empty which() would silently drop ALL inputs.
for(input in setdiff(inputs, "cycle_start")){
  # copy the current input into the scratch column `v` so that the plot code
  # can refer to a fixed column name (no eval(parse()) needed)
  cycles$v = cycles[[input]]
  # NOTE(review): the upper quantile is 0.955 while the lower one is 0.005;
  # confirm whether 0.995 was intended. Value kept unchanged here.
  x_limits = quantile(cycles$v, probs = c(0.005,0.955), na.rm = TRUE) + c(-1,1)
  g =  ggplot(cycles, aes(x = v, col = birth_control_CLUE)) + geom_freqpoly(aes(y = ..density..),binwidth = 1) +
    ggtitle(input) +
    scale_color_manual(values = c(cols$BC,"gray")) + 
    xlim(x_limits)
  print(g)
}
## Warning: Removed 15511 rows containing non-finite values (stat_bin).
## Warning: Removed 6 rows containing missing values (geom_path).

## Warning: Removed 14559 rows containing non-finite values (stat_bin).

## Warning: Removed 6 rows containing missing values (geom_path).

## Warning: Removed 13325 rows containing non-finite values (stat_bin).

## Warning: Removed 6 rows containing missing values (geom_path).

## Warning: Removed 12584 rows containing non-finite values (stat_bin).

## Warning: Removed 6 rows containing missing values (geom_path).

## Warning: Removed 13401 rows containing non-finite values (stat_bin).

## Warning: Removed 6 rows containing missing values (geom_path).

## Warning: Removed 18255 rows containing non-finite values (stat_bin).

## Warning: Removed 6 rows containing missing values (geom_path).

## Warning: Removed 16984 rows containing non-finite values (stat_bin).
## Warning: Removed 9 rows containing missing values (geom_path).

## Warning: Removed 12949 rows containing non-finite values (stat_bin).
## Warning: Removed 6 rows containing missing values (geom_path).

## Warning: Removed 8408 rows containing non-finite values (stat_bin).

## Warning: Removed 6 rows containing missing values (geom_path).

## Warning: Removed 13055 rows containing non-finite values (stat_bin).

## Warning: Removed 6 rows containing missing values (geom_path).

# Remove the scratch column `v` created for the plots above, if it is present.
j = which(colnames(cycles) == "v")
has_scratch_column = length(j) > 0
if(has_scratch_column){
  cycles = cycles[, -j]
}
rm(has_scratch_column)

0.1.3 Logistic Regression for interpretable significance of the input variables

tic()
# Logistic regression of the original birth-control label on all inputs,
# restricted to cycles whose user entered a birth-control method.
# The formula is built with reformulate() rather than by pasting and
# eval(parse())-ing a string of code.
# With a factor response, binomial glm() treats the first level
# ("none / condoms") as failure and any other level as success, so positive
# coefficients point towards "pill".
# Note: the Call printed by summary(glm_fit) now shows the reformulate()
# expression rather than the expanded formula.
glm_fit = glm(reformulate(inputs, response = "birth_control_CLUE"),
              data = cycles[cycles$birth_control_CLUE != 'did not enter',],
              family = 'binomial')
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
toc() # takes about 20 seconds on the full dataset
## 2.144 sec elapsed
summary(glm_fit)
## 
## Call:
## glm(formula = birth_control_CLUE ~ cycle_length + n_days_obs + 
##     n_pill + n_prot_sex + n_unprot_sex + diff_cl_median_3c + 
##     cl_sd_3c + cycle_nb + period_length + n_egg_white_fluid + 
##     cycle_start, family = "binomial", data = cycles[cycles$birth_control_CLUE != 
##     "did not enter", ])
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -8.4904  -0.9224   0.2727   0.9297   4.9644  
## 
## Coefficients:
##                      Estimate  Std. Error z value             Pr(>|z|)    
## (Intercept)       22.12138438  0.33129128  66.773 < 0.0000000000000002 ***
## cycle_length      -0.00778446  0.00051539 -15.104 < 0.0000000000000002 ***
## n_days_obs        -0.07079176  0.00077795 -90.998 < 0.0000000000000002 ***
## n_pill             0.18032950  0.00085931 209.854 < 0.0000000000000002 ***
## n_prot_sex         0.01262775  0.00283052   4.461 0.000008146859609551 ***
## n_unprot_sex      -0.03320169  0.00231321 -14.353 < 0.0000000000000002 ***
## diff_cl_median_3c  0.00767992  0.00052990  14.493 < 0.0000000000000002 ***
## cl_sd_3c           0.00346777  0.00034043  10.186 < 0.0000000000000002 ***
## cycle_nb           0.00049383  0.00065469   0.754              0.45067    
## period_length     -0.00774058  0.00257478  -3.006              0.00264 ** 
## n_egg_white_fluid -0.02267326  0.00278475  -8.142 0.000000000000000389 ***
## cycle_start       -0.00126213  0.00001951 -64.681 < 0.0000000000000002 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 357802  on 261103  degrees of freedom
## Residual deviance: 267600  on 261092  degrees of freedom
## AIC: 267624
## 
## Number of Fisher Scoring iterations: 6

0.1.4 Classifying cycles using a SVM

subset_perc = 5
sample_size = round(subset_perc*nrow(cycles)/100)
sample_size
## [1] 19539

We will only use a subset (5%) of the data to train the model.

This represents 19539 cycles.

And we only train on cycles of users that indicated pill, none or condoms and exclude users that did not enter their birth control in the settings

# Draw `sample_size` row indices (without replacement, sample()'s default)
# among the cycles whose original label is "pill" or "none / condoms".
# NOTE(review): no set.seed() is visible before this call, so the training
# subset (and every number reported below) changes from run to run.
j = sample( which(cycles$birth_control_CLUE %in% c("pill", "none / condoms")), sample_size)
cycles_training_set = cycles[j , ]

# fit
# The call is assembled as text so that the formula lists every element of
# `inputs`; probability = TRUE is what later allows predict() to return
# class probabilities.
tic()
eval(parse(text = paste0("svm_fit = parallelSVM(birth_control_CLUE ~ ",paste(inputs, collapse = " + "),", data = cycles_training_set , probability = TRUE)")))
toc()
## 15.498 sec elapsed
summary(svm_fit)
## 
## Call:
## parallelSVM(formula = birth_control_CLUE ~ cycle_length +  n_days_obs + n_pill + n_prot_sex + n_unprot_sex + diff_cl_median_3c +  cl_sd_3c + cycle_nb + period_length + n_egg_white_fluid +  cycle_start, data = cycles_training_set, probability = TRUE)
## 
## 
## Parameters: 
##    SVM-Type:  C-classficiation
##  SVM-Kernel:  radial
##        cost:  1
##       gamma:  0.09090909
## 
## Average Number of Support Vectors: 2366
##  
## 
##  ( 1183 1183 )
## 
## 
## Number of classes: 2
## 
## Levels:
## none / condoms pill did not enter
## 
## 
## 
# prediction
cycles$SVM_prob = NA
j = 1:nrow(cycles)
#j = which(cycles$user_id %in% unique(cycles$user_id)[1:1000])

tic()
SVM_prob = predict(svm_fit, newdata = cycles[j,inputs], decision.value = TRUE, probability = TRUE)
toc()
## 273.673 sec elapsed
SVM_prob_p = attr(SVM_prob, "probabilities")
head(SVM_prob_p, 20)
##    none / condoms      pill
## 1      0.31827050 0.6817295
## 2      0.20848214 0.7915179
## 3      0.10631605 0.8936839
## 4      0.12702390 0.8729761
## 5      0.15394771 0.8460523
## 6      0.22664275 0.7733572
## 7      0.20503387 0.7949661
## 8      0.19596774 0.8040323
## 9      0.35874648 0.6412535
## 10     0.70846520 0.2915348
## 11     0.14809252 0.8519075
## 12     0.47909448 0.5209055
## 13     0.77996219 0.2200378
## 14     0.76825920 0.2317408
## 15     0.79478856 0.2052114
## 16     0.25071617 0.7492838
## 17     0.09185932 0.9081407
## 18     0.12548530 0.8745147
## 19     0.10561323 0.8943868
## 20     0.11269646 0.8873035
cycles$SVM_prob[j] = SVM_prob_p[,which(colnames(SVM_prob_p) == "pill")]
#comparison
table_cont = table(cycles$SVM_prob<0.5, cycles$birth_control_CLUE)

ggplot(cycles, aes(x =SVM_prob, fill = birth_control_CLUE )) + geom_density(alpha = 0.5, bw = 0.01, col = NA) + geom_vline(xintercept = 0.5, linetype = 2)+ xlim(c(0,1))+ scale_fill_manual(values = cols$BC3) + facet_grid(birth_control_CLUE ~ .)

table_cont 
##        
##         none / condoms   pill did not enter
##   FALSE          30564 113361         47071
##   TRUE           83521  33658         82596
round(t(t(table_cont)/apply(table_cont, 2, sum)), digits = 2)
##        
##         none / condoms pill did not enter
##   FALSE           0.27 0.77          0.36
##   TRUE            0.73 0.23          0.64
rm(SVM_prob, table_cont, svm_fit)
save(cycles, file = paste0(IO$tmp_data, "cycles_SVM.Rdata"))
# Visual check on a small sample of users (the first 80 user ids in order of
# appearance): SVM pill probability per cycle, faceted by the original label.
# head() is used instead of [1:80] so that fewer than 80 users does not
# introduce NA ids.
sub_cycles = cycles[cycles$user_id %in% head(unique(cycles$user_id), 80),]
# order the users on the y axis by their original birth-control label
sub_cycles$user_id_n = factor(sub_cycles$user_id_n, levels = unique(sub_cycles$user_id_n[order(sub_cycles$birth_control_CLUE)]))

g_svm = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = SVM_prob)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y")+ scale_color_gradient2(low = cols$NC, high = cols$pill, mid = "gray90", midpoint = 0.5)+theme(legend.position="top")

grid.arrange(g_svm, nrow = 1)

# the temporary data frame is called sub_cycles (rm(sub_cy, ...) only raised
# an "object not found" warning and left sub_cycles in memory)
rm(sub_cycles, g_svm)

0.1.5 Adding memory with an HMM

Model initialisation:

hmm = list()

# defining model
hmm$bc_states = c("none / condoms","pill","X-nc", "X-pill","X-did-not-enter") # we split the state X into 3 sub-states to account for the on-boarding BC info of a given user
hmm$symbols = c(paste0("p_",0:10),"X-nc","X-pill","X-did-not-enter")

#initializing model

bc_emission_prob.N = c(dbeta(seq(0,1,by = 0.1), shape1 = 1 , shape2 = 3)+0.2,0,0,0) # more likely to observe LOW score for N/C
bc_emission_prob.P = c(dbeta(seq(0,1,by = 0.1), shape1 = 3 , shape2 = 1)+0.2,0,0,0) # more likely to observe HIGH score for pill
bc_emission_prob.Xnc = c(rep(0,11),1,0,0) # we only observe Xs when we switch users
bc_emission_prob.Xpill = c(rep(0,11),0,1,0) # we only observe Xs when we switch users
bc_emission_prob.Xdne = c(rep(0,11),0,0,1) # we only observe Xs when we switch users

plot(seq(0,1,by = 0.1) , bc_emission_prob.N[1:11], type = "l", col = cols$NC, lwd = 2)
points(seq(0,1,by = 0.1) , bc_emission_prob.P[1:11], type = "l", col = cols$pill, lwd = 2)

# building and normalizing emission prob matrix
# One row per hidden state (same order as hmm$bc_states), one column per
# symbol (same order as hmm$symbols); byrow = TRUE because each
# bc_emission_prob.* vector is one state's full emission distribution.
hmm$bc_emission_prob = matrix(c(bc_emission_prob.N,bc_emission_prob.P, bc_emission_prob.Xnc, bc_emission_prob.Xpill,bc_emission_prob.Xdne),nrow = length(hmm$bc_states), ncol = length(hmm$symbols), byrow = TRUE)
# divide each row by its sum so that every state's emissions sum to 1
hmm$bc_emission_prob = hmm$bc_emission_prob/rowSums(hmm$bc_emission_prob)

# clean up all five temporary vectors (bc_emission_prob.Xdne was previously left behind)
rm(bc_emission_prob.N,bc_emission_prob.P,bc_emission_prob.Xnc, bc_emission_prob.Xpill, bc_emission_prob.Xdne)


hmm$start_prob_nc = 0.6  #
hmm$bc_transition_prob = matrix(
  c(0.9,0.08,0.01,0.01,0.01,
    0.08,0.9,0.01,0.01,0.01,
    hmm$start_prob_nc,1-hmm$start_prob_nc,0,0,0,
    1-hmm$start_prob_nc,hmm$start_prob_nc,0,0,0,
    0.5, 0.5, 0,0,0), 
  nrow = length(hmm$bc_states), ncol = length(hmm$bc_states), byrow = TRUE)

hmm$start_prob = c(hmm$start_prob_nc,1-hmm$start_prob_nc, 0.5)[as.numeric(cycles$birth_control_CLUE[1])]
hmm$start_prob = c(hmm$start_prob, 1 - hmm$start_prob, 0,0,0)

hmm_mod = initHMM(
  States =  hmm$bc_states, Symbols =  hmm$symbols, 
  startProbs =  hmm$start_prob, transProbs =  hmm$bc_transition_prob, 
  emissionProbs =  hmm$bc_emission_prob)

Observations:

#cy = cy[order(cy$user_id, cy$cycle_nb),]

# One symbol per cycle: the SVM pill probability binned to p_0 ... p_10.
obs_p =  paste0("p_",round(10*cycles$SVM_prob))

# inserting Xs
# x = row indices of the LAST cycle of each user (except the final user);
# this relies on `cycles` being sorted so that each user's rows are contiguous.
x = which(diff(cycles$user_id_n) !=0) 
# One X symbol per user boundary, carrying the on-boarding label of the user
# who STARTS after the boundary (row x + 1), since the X state's transition
# row acts as that user's start distribution.
# The previous code built obs_x for every row, so after re-ordering only its
# first length(x) entries were used and the boundary symbols carried the
# labels of unrelated rows.
# Note: this changes the Viterbi labels downstream, so the tables, counts and
# figures echoed later in this report no longer apply as printed.
obs_x = c("X-nc","X-pill","X-did-not-enter")[as.numeric(cycles$birth_control_CLUE[x + 1])]
# interleave: cycle i gets key 10*i, the X after row x gets key 10*x + 5
id = c(seq_along(obs_p)*10, 10*x+5)
o = order(id)
obs = c(obs_p, obs_x)[o]

rm(x, id, o, obs_p, obs_x)

Estimating most likely BC for each cycle using the Viterbi algorithm:

# running viterbi
tic()
vit_init = viterbi(hmm = hmm_mod, obs = obs) 
toc()
## 31.119 sec elapsed
# Drop the X-* states that sit between users so that one state per cycle
# remains. A logical mask is used instead of -grep(): with no match (e.g. a
# single user, hence no X symbol) a negative empty index would return an
# empty vector instead of the full path.
vit_init_no_x = vit_init[!grepl("X-", vit_init, fixed = TRUE)]
# guard against silent recycling if the path and the cycles got out of sync
stopifnot(length(vit_init_no_x) == nrow(cycles))
cycles$BC_hmm_init_SVM = vit_init_no_x

rm(vit_init, vit_init_no_x, hmm_mod, obs)
table_cont_init_bc = table(cycles$BC_hmm_init_SVM, cycles$birth_control_CLUE)
round(t(t(table_cont_init_bc)/apply(table_cont_init_bc, 2, sum)), digits = 2)
##                 
##                  none / condoms pill did not enter
##   none / condoms           0.80 0.17          0.67
##   pill                     0.20 0.83          0.33
save(cycles, file = paste0(IO$tmp_data, "cycles_SVM_HMM.Rdata"))
# Visual check on a small sample of users (the first 80 user ids in order of
# appearance): HMM/Viterbi label (left) next to the SVM pill probability
# (right), faceted by the original label.
# head() is used instead of [1:80] so that fewer than 80 users does not
# introduce NA ids.
sub_cycles = cycles[cycles$user_id %in% head(unique(cycles$user_id), 80),]
# order the users on the y axis by their original birth-control label
sub_cycles$user_id_n = factor(sub_cycles$user_id_n, levels = unique(sub_cycles$user_id_n[order(sub_cycles$birth_control_CLUE)]))


g_vit_init = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = BC_hmm_init_SVM)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y") + theme(legend.position="top")+ scale_color_manual(values = cols$BC)
g_svm = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = SVM_prob)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y")+ scale_color_gradient2(low = cols$NC, high = cols$pill, mid = "gray90", midpoint = 0.5)+ theme(legend.position="top")

grid.arrange(g_vit_init, g_svm, nrow = 1)

# the temporary data frame is called sub_cycles (rm(sub_cy, ...) only raised
# an "object not found" warning and left sub_cycles in memory)
rm(sub_cycles, g_svm, g_vit_init)

0.1.6 Confidence Score

We define a confidence score that accounts for the asymmetry in the probability distribution for pill vs N/C cycles. The confidence score is defined as the cumulative distribution function of the SVM probabilities for the cycles that stayed in the same BC class.

by = 0.001
h_pill = hist(cycles$SVM_prob[(cycles$birth_control_CLUE == "pill") & (cycles$BC_hmm_init_SVM == "pill")], breaks = seq(0-by/2,1+by/2,by = by), plot = FALSE)
h_nc = hist(1-cycles$SVM_prob[(cycles$birth_control_CLUE != "pill") & (cycles$BC_hmm_init_SVM != "pill")], breaks = seq(0-by/2,1+by/2,by = by), plot = FALSE)

h_pill = data.frame(mids = h_pill$mids, density = h_pill$density, cumsum = cumsum(h_pill$density)/max(cumsum(h_pill$density)))
h_nc = data.frame(mids = h_nc$mids, density = h_nc$density, cumsum = cumsum(h_nc$density)/max(cumsum(h_nc$density)))


confidence_thres = 0.1


g_pill = ggplot(h_pill, aes(x = mids, y = cumsum) )+ 
  geom_line()+ geom_point() + 
  geom_line(aes(y = density/10)) + 
  geom_hline(yintercept = confidence_thres)+
  xlab("SVM prob") + ylab("confidence score")+
  ggtitle("pill confidence score")

g_nc = ggplot(h_nc, aes(x = mids, y = cumsum) )+ 
  geom_line()+ geom_point() + 
  geom_line(aes(y = density/10))+ 
  geom_hline(yintercept = confidence_thres)+
  xlab("SVM prob") + ylab("confidence score")+
  ggtitle("none/condoms confidence score")

grid.arrange(g_pill, g_nc)

# Assign a confidence score to every cycle by looking up its SVM probability
# in the normalized cumulative histograms built above (bin width `by`).
# The lookup key is the probability rounded to 3 digits, matched against the
# rounded bin mid-points.
# NOTE(review): match() returns NA when a rounded value has no counterpart
# among the mid-points; such cycles would get conf = NA - confirm none occur.
cycles$conf = 0
# cycles labelled "pill" by the HMM: score read from the pill curve at P(pill)
cycles$conf[cycles$BC_hmm_init_SVM == "pill"] = 
  h_pill$cumsum[match(round(cycles$SVM_prob[cycles$BC_hmm_init_SVM == "pill"],digits = 3),round(h_pill$mids, digits = 3))]
# all other cycles: score read from the none/condoms curve at 1 - P(pill)
cycles$conf[cycles$BC_hmm_init_SVM != "pill"] = 
  h_nc$cumsum[match(round(1-cycles$SVM_prob[cycles$BC_hmm_init_SVM != "pill"],digits = 3),round(h_nc$mids, digits = 3))]


g = ggplot(cycles, aes(x = conf, fill = birth_control_CLUE))
g = g  + geom_histogram(alpha = 0.5, position = "identity", binwidth = 0.05) + 
  facet_grid( BC_hmm_init_SVM ~ . ) + 
  ggtitle("Distribution of cycles confidence score")
g + geom_vline(xintercept = confidence_thres, size = 0.3, linetype = 2)

0.1.7 New BC labels

  • we take all users for which the new label is the same as the original label
  • for users that have cycles with different labels, we only take users for which all cycles have a high enough confidence score for the new labels and for which the average confidence over that sequence of consecutive cycles is high enough as well (at least twice the confidence threshold used for individual cycles).
cycles$BC = NA


# all cycles of a given user has a confidence score that is high enough
conf = (cycles$conf >= confidence_thres)
conf_u = aggregate(conf, by = list(user_id = cycles$user_id), FUN = all)
user_ids_all_cycles_above_threshold = conf_u$user_id[conf_u$x]

# average confidence score per user
conf_av_u = aggregate(cycles$conf, by = list(user_id = cycles$user_id), FUN = mean)
hist(conf_av_u$x, breaks = 100)

user_ids_average_conf_score = conf_av_u$user_id[conf_av_u$x >= 2*confidence_thres]

# keeping the cycles from users that had all their cycles matching the original BC
same = (cycles$birth_control_CLUE == cycles$BC_hmm_init_SVM)
same_u = aggregate(same, by = list(user_id = cycles$user_id), FUN = all)
user_ids_same_label = same_u$user_id[same_u$x]


user_ids = union(intersect(user_ids_all_cycles_above_threshold,user_ids_average_conf_score), user_ids_same_label)
length(user_ids)
## [1] 13300
cycles$BC[cycles$user_id %in% user_ids] = cycles$BC_hmm_init_SVM[cycles$user_id %in% user_ids]
cycles$BC[is.na(cycles$BC)] = "unclear"

table(cycles$BC)
## 
## none / condoms           pill        unclear 
##         102028         112972         175771
table(cycles$BC, cycles$birth_control_CLUE)
##                 
##                  none / condoms   pill did not enter
##   none / condoms          72760   3481         25787
##   pill                     2539 100742          9691
##   unclear                 38786  42796         94189
# Save the re-classified cycles next to the other intermediate files.
# IO$tmp_data is the directory used by every other temporary save/copy in this
# document; the previous IO$tmp_Rdata is not used anywhere else here and, if
# that element does not exist, paste0() would silently produce a path in the
# working directory.
# NOTE(review): confirm that no downstream script loads this file from an
# IO$tmp_Rdata location.
save(cycles, file = paste0(IO$tmp_data,"cycles_BC_re_classification_final.Rdata"))
sub_cycles = cycles[cycles$user_id %in% unique(cycles$user_id)[1:80],]
sub_cycles$user_id_n = factor(sub_cycles$user_id_n, levels = unique(sub_cycles$user_id_n[order(sub_cycles$birth_control_CLUE)]))


g_vit_init = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = BC_hmm_init_SVM)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y") + theme(legend.position="top")+ scale_color_manual(values = cols$BC)
g_svm = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = SVM_prob)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y")+ scale_color_gradient2(low = cols$NC, high = cols$pill, mid = "gray90", midpoint = 0.5)+ theme(legend.position="top")

grid.arrange(g_vit_init, g_svm, nrow = 1)

rm(sub_cycles, g_svm, g_vit_init)



sub_cycles = cycles[cycles$user_id %in% unique(cycles$user_id)[1:80],]
sub_cycles$user_id_n = factor(sub_cycles$user_id_n, levels = unique(sub_cycles$user_id_n[order(sub_cycles$birth_control_CLUE)]))
sub_cycles$BC = factor(sub_cycles$BC, levels = c(as.character(par$BC_dict$name),"unclear"))

g_BC = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = BC)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y") + theme(legend.position="top")+ scale_color_manual(values = cols$BC3)

g_vit_init = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = BC_hmm_init_SVM)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y") + theme(legend.position="top")+ scale_color_manual(values = cols$BC)

g_svm = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = SVM_prob)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y")+ scale_color_gradient2(low = cols$NC, high = cols$pill, mid = "gray90", midpoint = 0.5)+ theme(legend.position="top")

g_conf = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = conf)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y")+
  scale_color_gradientn(colours = c("red","gray90","black"), 
                         values = rescale(c(0,confidence_thres,1)),
                         limits=c(0,1))+
#scale_color_gradient2(low = "red", high = "black",mid = "gray90", midpoint = confidence_thres, limits = c(0,2*confidence_thres))+
  theme(legend.position="top")

grid.arrange(g_BC, g_vit_init, g_svm, g_conf, nrow = 1)

rm(sub_cycles, g_BC, g_svm, g_vit_init, g_conf)
agg = aggregate(BC ~ user_id_n, cycles[cycles$BC != "unclear",], lu)
sum(agg$BC > 1)
## [1] 801
sub_cycles = cycles[cycles$user_id_n %in% agg$user_id_n[which(agg$BC>1)[1:80]],]
sub_cycles$user_id_n = factor(sub_cycles$user_id_n, levels = unique(sub_cycles$user_id_n[order(sub_cycles$birth_control_CLUE)]))
sub_cycles$BC = factor(sub_cycles$BC, levels = c(as.character(par$BC_dict$name),"unclear"))

g_BC = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = BC)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y") + theme(legend.position="top")+ scale_color_manual(values = cols$BC3)

g_vit_init = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = BC_hmm_init_SVM)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y") + theme(legend.position="top")+ scale_color_manual(values = cols$BC)

g_svm = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = SVM_prob)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y")+ scale_color_gradient2(low = cols$NC, high = cols$pill, mid = "gray90", midpoint = 0.5)+ theme(legend.position="top")

g_conf = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = conf)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y")+
  scale_color_gradientn(colours = c("red","gray90","black"), 
                         values = rescale(c(0,confidence_thres,1)),
                         limits=c(0,1))+
#scale_color_gradient2(low = "red", high = "black",mid = "gray90", midpoint = confidence_thres, limits = c(0,2*confidence_thres))+
  theme(legend.position="top")

grid.arrange(g_BC, g_vit_init, g_svm, g_conf, nrow = 1)

rm(sub_cycles, g_BC, g_svm, g_vit_init, g_conf)


sub_cycles = cycles[cycles$user_id_n %in% agg$user_id_n[which(agg$BC==1)[1:80]],]
sub_cycles$user_id_n = factor(sub_cycles$user_id_n, levels = unique(sub_cycles$user_id_n[order(sub_cycles$birth_control_CLUE)]))
sub_cycles$BC = factor(sub_cycles$BC, levels = c(as.character(par$BC_dict$name),"unclear"))

g_BC = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = BC)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y") + theme(legend.position="top")+ scale_color_manual(values = cols$BC3)

g_vit_init = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = BC_hmm_init_SVM)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y") + theme(legend.position="top")+ scale_color_manual(values = cols$BC)

g_svm = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = SVM_prob)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y")+ scale_color_gradient2(low = cols$NC, high = cols$pill, mid = "gray90", midpoint = 0.5)+ theme(legend.position="top")

g_conf = ggplot(sub_cycles, aes(x = cycle_nb, y = user_id_n, col = conf)) + geom_point() + facet_grid(birth_control_CLUE ~ ., scale = "free_y")+
  scale_color_gradientn(colours = c("red","gray90","black"), 
                         values = rescale(c(0,confidence_thres,1)),
                         limits=c(0,1))+
#scale_color_gradient2(low = "red", high = "black",mid = "gray90", midpoint = confidence_thres, limits = c(0,2*confidence_thres))+
  theme(legend.position="top")

grid.arrange(g_BC, g_vit_init, g_svm, g_conf, nrow = 1)

rm(sub_cycles, g_BC, g_svm, g_vit_init, g_conf)
rm(agg)